package server

import (
	"bytes"
	"encoding/binary"
	//"epg/ast"
	"epg/conf"
	"epg/consts"

	"github.com/pingcap/tidb/ast"
	"github.com/pingcap/tidb/mysql"
	"github.com/pingcap/tidb/parser"
	"github.com/pingcap/tidb/util"

	"epg/privilege"
	"epg/stats"
	"epg/utils/stringutil"
	"errors"
	"fmt"
	"io"
	"net"
	"sync"
	"sync/atomic"
	"time"

	"github.com/zeast/logs"
)

var (
	// errCannotGetConn is a sentinel error whose message reports that no
	// DBPool connection could be obtained.
	errCannotGetConn = errors.New("Can not get DBPool connection")
)

/*
const defaultCapability = mysql.ClientLongPassword | mysql.ClientLongFlag |
	mysql.ClientConnectWithDB | mysql.ClientProtocol41 |
	mysql.ClientTransactions | mysql.ClientSecureConnection | mysql.ClientFoundRows |
	mysql.ClientMultiStatements | mysql.ClientMultiResults | mysql.ClientLocalFiles |
	mysql.ClientConnectAtts | mysql.ClientPluginAuth
*/

// defaultCapability is the set of MySQL capability flags this server
// advertises to clients in the initial handshake packet (it is written, split
// into its lower and upper 16 bits, by writeInitHandshake).
var defaultCapability = mysql.ClientLongPassword | mysql.ClientLongFlag | mysql.ClientConnectWithDB | mysql.ClientProtocol41 | mysql.ClientTransactions | mysql.ClientSecureConnection

// ComNameMapping maps a MySQL command byte (the first byte of a command
// packet) to a printable command name. The names are used to label the
// per-command statistics in dispatch and the debug log lines in logSQL.
var ComNameMapping = map[byte]string{
	mysql.ComQuit:             "ComQuit",
	mysql.ComInitDB:           "ComInitDB",
	mysql.ComQuery:            "ComQuery",
	mysql.ComFieldList:        "ComFieldList",
	mysql.ComCreateDB:         "ComCreateDB",
	mysql.ComDropDB:           "ComDropDB",
	mysql.ComRefresh:          "ComRefresh",
	mysql.ComShutdown:         "ComShutdown",
	mysql.ComStatistics:       "ComStatistics",
	mysql.ComProcessInfo:      "ComProcessInfo",
	mysql.ComConnect:          "ComConnect",
	mysql.ComProcessKill:      "ComProcessKill",
	mysql.ComDebug:            "ComDebug",
	mysql.ComPing:             "ComPing",
	mysql.ComTime:             "ComTime",
	mysql.ComDelayedInsert:    "ComDelayedInsert",
	mysql.ComChangeUser:       "ComChangeUser",
	mysql.ComBinlogDump:       "ComBinlogDump",
	mysql.ComTableDump:        "ComTableDump",
	mysql.ComConnectOut:       "ComConnectOut",
	mysql.ComRegisterSlave:    "ComRegisterSlave",
	mysql.ComStmtPrepare:      "ComStmtPrepare",
	mysql.ComStmtExecute:      "ComStmtExecute",
	mysql.ComStmtSendLongData: "ComStmtSendLongData",
	mysql.ComStmtClose:        "ComStmtClose",
	mysql.ComStmtReset:        "ComStmtReset",
	mysql.ComSetOption:        "ComSetOption",
	mysql.ComStmtFetch:        "ComStmtFetch",
}

// Client holds the state of one client connection to the proxy: the network
// connection and packet reader/writer, the identity established during the
// handshake, and the backend connection currently held on its behalf.
type Client struct {
	// mu is not used by the methods visible in this part of the file.
	// NOTE(review): confirm whether anything else locks it.
	mu sync.Mutex

	// io reads and writes MySQL packets on netConn (built by newPacketIO).
	io        *packetIO
	// buf comes from newBufPool and is handed back with Put in Close.
	buf       *bufPool
	// netConn is the accepted client TCP connection; Run closes it on exit.
	netConn   *net.TCPConn
	// stmt is the set of prepared-statement IDs registered by
	// handleStmtPrepare and removed by handleStmtClose.
	stmt      map[uint32]struct{}
	// txConn is the backend connection obtained by grabConn, or nil when none
	// is held. dispatch releases it once no transaction or statement needs it.
	txConn    *DBContext
	// die is closed by Close and makes the Run loop return.
	die       chan struct{}
	user      string // user account name, set during the handshake
	host      string // host part of the client's remote address
	dbname    string // database currently selected by the client
	nodeAddr  string // address of the backend node behind txConn
	// connectID is this connection's ID, taken from the clientBaseID counter
	// and sent to the client in the initial handshake.
	connectID uint32

	// salt is the 20-byte random challenge sent in the handshake and passed
	// to the auth checker together with the client's response.
	salt       []byte
	// status holds the MySQL server status flags reported to the client.
	status     uint16
	// capability holds the capability flags the client sent in its handshake
	// response.
	capability uint32
	// collation is passed to txConn.UseDB alongside dbname.
	collation  string
	// sequence is not read by the code visible here, which uses io.sequence.
	sequence   uint8
	// writeTimeout time.Duration
	// queryLog is the last query text recorded by logSQL; dispatch passes it
	// to SlowLog.
	queryLog string

	// sqlParser parses COM_QUERY text in handleQuery.
	sqlParser *parser.Parser
}

// newClient builds the per-connection state for an accepted TCP connection.
// It allocates the packet reader/writer, buffer pool, connection ID, random
// handshake salt and SQL parser, and starts the client in autocommit mode
// with the default charset.
func newClient(conn *net.TCPConn) *Client {
	c := new(Client)
	c.netConn = conn
	c.die = make(chan struct{})
	c.io = newPacketIO(conn)
	c.buf = newBufPool()
	c.connectID = atomic.AddUint32(&clientBaseID, 1)
	c.salt = util.RandomBuf(20)
	c.status = mysql.ServerStatusAutocommit
	c.stmt = make(map[uint32]struct{})
	c.sqlParser = parser.New()
	c.collation = mysql.DefaultCharset

	// Only the host part of the remote address is kept. The port and any
	// split error are deliberately discarded, so host is empty on failure.
	remote := c.netConn.RemoteAddr().String()
	c.host, _, _ = net.SplitHostPort(remote)

	// TODO: no read/write deadline is set on the client connection, because
	// PDO and some other drivers never ping the server.

	return c
}

// Run performs the handshake and then serves client packets until either the
// package-level die channel or this client's own die channel is closed.
// On exit it decrements the per-user counter (when a user was recorded),
// closes the network connection and logs the close.
func (c *Client) Run() {
	begin := time.Now()

	defer func() {
		// TODO: a client that sent an empty user name is never recorded in
		// c.user, so the counter must not be decremented for it.
		if c.user != "" {
			CountMgr.Decr(c.user)
		}
		c.netConn.Close()
		logs.Infof("[%d] Closed %s@%s/%s", c.connectID, c.user, c.Addr(), c.dbname)
	}()

	if err := c.Handshake(); err != nil {
		logs.Debugf("握手失败: %s", err)
		return
	}
	stats.Stater.Timing(fmt.Sprintf("time.%s.handshake", c.user), time.Since(begin))
	logs.Infof("[%d] Opened %s@%s/%s", c.connectID, c.user, c.Addr(), c.dbname)

	for {
		// Non-blocking check for shutdown before reading the next packet.
		select {
		case <-c.die:
			return
		case <-die:
			return
		default:
		}

		err := c.Accept()
		if err == nil || err == io.EOF || err == io.ErrUnexpectedEOF {
			continue
		}
		logs.Errorf("[%d] sql 处理失败 %s", c.connectID, err)
	}
}

// Addr returns the client's remote network address as a string.
func (c *Client) Addr() string {
	remote := c.netConn.RemoteAddr()
	return remote.String()
}

// Close releases the resources held for the client and signals Run to stop.
//
// It closes the backend connection (if any), returns the buffer pool, and
// closes the die channel. Every step is guarded, so calling Close more than
// once is safe; previously a second call panicked with "close of closed
// channel". Close is not safe for concurrent use from several goroutines.
func (c *Client) Close() {
	if c.txConn != nil {
		// The flag is true when a transaction is open or prepared statements
		// still exist on the backend connection.
		// NOTE(review): presumably true makes DBContext discard the
		// connection instead of reusing it; confirm against DBContext.Close.
		dirty := c.isTx() || len(c.stmt) != 0
		c.txConn.Close(dirty)
		c.txConn = nil
	}

	if c.buf != nil {
		c.buf.Put()
		c.buf = nil
	}

	select {
	case <-c.die:
		// Already closed by an earlier call.
	default:
		close(c.die)
	}
}

/******************************************************************************
*                           Handshake Process                                 *
******************************************************************************/

// Handshake runs the connection handshake: it sends the initial handshake
// packet, reads and validates the client's response, and confirms with an OK
// packet.
//
// It returns the first error encountered. A failure to send the final OK
// packet is now reported too; it used to be dropped, so a dead connection was
// treated as successfully authenticated.
func (c *Client) Handshake() error {
	if err := c.writeInitHandshake(); err != nil {
		return err
	}
	if err := c.readHandshakeRespone(); err != nil {
		// Best effort: tell the client why it was rejected. The original
		// error is the one worth returning, so the write result is ignored.
		c.writeError(err)
		return err
	}

	err := c.writeOK()
	// The command phase starts with a fresh packet sequence number.
	c.io.sequence = 0
	return err
}

// writeInitHandshake builds and sends the server's initial handshake packet.
// The first four bytes of the buffer are left zero for the packet header.
func (c *Client) writeInitHandshake() error {
	id := c.connectID
	caps := defaultCapability
	version := fmt.Sprintf("%v-%v", consts.ServerVersion, consts.AppName)

	pkt := make([]byte, 4, 128)

	// Protocol version 10.
	pkt = append(pkt, 10)
	// NUL-terminated server version string.
	pkt = append(pkt, version...)
	pkt = append(pkt, 0)
	// Connection ID, little-endian.
	pkt = append(pkt, byte(id), byte(id>>8), byte(id>>16), byte(id>>24))
	// auth-plugin-data part 1, followed by a zero filler byte.
	pkt = append(pkt, c.salt[0:8]...)
	pkt = append(pkt, 0)
	// Lower two bytes of the capability flags.
	pkt = append(pkt, byte(caps), byte(caps>>8))
	// Default collation ID.
	pkt = append(pkt, uint8(mysql.DefaultCollationID))
	// Server status flags.
	pkt = append(pkt, byte(mysql.ServerStatusAutocommit), byte(mysql.ServerStatusAutocommit>>8))
	// Upper two bytes of the capability flags.
	pkt = append(pkt, byte(caps>>16), byte(caps>>24))
	// Filler byte 0x15, then ten reserved zero bytes.
	pkt = append(pkt, 0x15)
	pkt = append(pkt, make([]byte, 10)...)
	// auth-plugin-data part 2, followed by a zero terminator.
	pkt = append(pkt, c.salt[8:]...)
	pkt = append(pkt, 0)

	return c.io.writePacket(pkt)
}

// readHandshakeRespone reads the client's handshake response and extracts the
// capability flags, charset, user name, auth response and optional database.
//
// The packet comes from an unauthenticated peer, so every offset is checked
// against the packet length before use. The previous code sliced blindly and
// panicked on truncated packets or missing NUL terminators.
//
// It returns an error when the packet is malformed, the user name is empty,
// the per-user counter refuses the connection, or authentication fails.
// Authentication fails closed: a negative result from the auth checker is an
// error even when the checker supplies none.
func (c *Client) readHandshakeRespone() error {
	data, err := c.io.readPacket()
	if err != nil {
		return err
	}
	errMalformed := errors.New("malformed handshake response packet")

	// Fixed-size prefix: capability flags (4), max packet size (4),
	// charset (1) and 23 reserved bytes.
	const fixedLen = 4 + 4 + 1 + 23
	if len(data) < fixedLen {
		return errMalformed
	}
	c.capability = binary.LittleEndian.Uint32(data[:4])
	// The max packet size is skipped. The charset byte is resolved through
	// CharsetNames; clients can change it later with SET NAMES.
	c.collation = CharsetNames[data[8]]
	pos := fixedLen

	// NUL-terminated user name.
	nameLen := bytes.IndexByte(data[pos:], 0)
	if nameLen < 0 {
		return errMalformed
	}
	uname := string(data[pos : pos+nameLen])
	pos += nameLen + 1
	if uname == "" {
		return mysql.NewErr(mysql.ErrUsername)
	}

	// Enforce the per-user connection counter. c.user is only set once the
	// increment succeeded, so Run decrements exactly what was counted.
	if err := CountMgr.Incr(uname); err != nil {
		return err
	}
	c.user = uname

	// One length byte followed by the auth response.
	if pos >= len(data) {
		return errMalformed
	}
	authLen := int(data[pos])
	pos++
	if authLen == 0 {
		return mysql.NewErr(mysql.ErrAccessDenied, c.user, c.netConn.RemoteAddr().String(), "NO")
	}
	if pos+authLen > len(data) {
		return errMalformed
	}

	auth := data[pos : pos+authLen]
	if ok, authErr := privilege.AuthChecker.Auth(c.user+"@"+c.host, c.salt, auth); !ok {
		if authErr != nil {
			return authErr
		}
		// Never treat "not ok" as success just because no error came back.
		return mysql.NewErr(mysql.ErrAccessDenied, c.user, c.netConn.RemoteAddr().String(), "YES")
	}
	pos += authLen

	// Optional database name, normally NUL-terminated. A missing terminator
	// is tolerated by taking the rest of the packet.
	if c.capability&uint32(mysql.ClientConnectWithDB) > 0 && pos < len(data) {
		db := data[pos:]
		if end := bytes.IndexByte(db, 0); end >= 0 {
			db = db[:end]
		}
		c.dbname = string(db)
	}
	return nil
}

// writeOK sends an OK packet with zero affected rows and zero last-insert ID.
// When the client negotiated ClientProtocol41, the current status flags and a
// zero warning count are appended.
func (c *Client) writeOK() error {
	pkt := make([]byte, 4, 32)
	pkt = append(pkt, mysql.OKHeader)

	// Two length-encoded zeros: affected rows, then last insert ID.
	for i := 0; i < 2; i++ {
		pkt = appendLengthEncodedInteger(pkt, 0)
	}

	if c.capability&uint32(mysql.ClientProtocol41) != 0 {
		pkt = append(pkt, byte(c.status), byte(c.status>>8), 0, 0)
	}

	return c.io.writePacket(pkt)
}
// writeError sends e to the client as an error packet. An error that is not a
// *mysql.SQLError is wrapped as ErrUnknown with the default SQL state. The
// '#' marker and SQL state are only written when the client negotiated
// ClientProtocol41.
func (c *Client) writeError(e error) error {
	sqlErr, isSQLErr := e.(*mysql.SQLError)
	if !isSQLErr {
		sqlErr = &mysql.SQLError{Code: mysql.ErrUnknown, Message: e.Error(), State: mysql.DefaultMySQLState}
	}

	pkt := make([]byte, 4, 16+len(sqlErr.Message))
	pkt = append(pkt, mysql.ErrHeader)
	pkt = append(pkt, byte(sqlErr.Code), byte(sqlErr.Code>>8))

	if c.capability&uint32(mysql.ClientProtocol41) != 0 {
		pkt = append(pkt, '#')
		pkt = append(pkt, sqlErr.State...)
	}

	pkt = append(pkt, sqlErr.Message...)
	return c.io.writePacket(pkt)
}

// isTx reports whether the client counts as being inside a transaction:
// either the in-transaction status flag is set or autocommit is switched off.
func (c *Client) isTx() bool {
	inTrans := c.status&mysql.ServerStatusInTrans != 0
	autocommit := c.status&mysql.ServerStatusAutocommit != 0
	return inTrans || !autocommit
}

// isStmt reports whether the client has at least one registered prepared
// statement.
func (c *Client) isStmt() bool {
	return len(c.stmt) != 0
}

// grabConn makes sure the client holds a backend connection. It fails when no
// database is selected, keeps an already held connection, and otherwise asks
// NodeMgr for one and records the node address.
func (c *Client) grabConn() error {
	switch {
	case c.dbname == "":
		return mysql.NewErr(mysql.ErrWrongDBName, c.dbname)
	case c.txConn != nil:
		// A connection is already held; reuse it.
		return nil
	}

	conn, err := NodeMgr.GetConn(c.user, c.dbname)
	if err != nil {
		return err
	}

	c.nodeAddr = conn.addr()
	c.txConn = conn
	return nil
}

// Accept reads one command packet from the client and dispatches it.
// A read error closes the client. A dispatch error is sent to the client as
// an error packet, and additionally closes the client when it is a net.Error.
// The packet sequence number is reset on every path.
func (c *Client) Accept() (err error) {
	var pkt []byte
	if pkt, err = c.io.readPacket(); err != nil {
		c.io.sequence = 0
		c.Close()
		logs.Debugf("[%d] readPacket err:%s, remote:%s", c.connectID, err, c.host)
		return err
	}

	if err = c.dispatch(pkt); err != nil {
		c.writeError(err)
		// A network-level failure ends the session.
		if _, isNetErr := err.(net.Error); isNetErr {
			c.Close()
		}
	}

	c.io.sequence = 0
	return err
}

/******************************************************************************
*                          Dispatch Process                                   *
******************************************************************************/
// dispatch routes one command packet to its handler, then updates the
// backend-connection state and records statistics, the SQL log and the slow
// log for the command.
//
// data is the packet payload; its first byte is the command. An empty payload
// is rejected with an error instead of panicking on data[0], since the packet
// comes straight from the client. The returned error is the handler's error
// (Accept reports it to the client).
func (c *Client) dispatch(data []byte) (err error) {
	if len(data) == 0 {
		return errors.New("empty command packet")
	}

	start := time.Now()
	cmd := data[0]
	switch cmd {
	case mysql.ComQuit:
		err = c.handleQuit()

	case mysql.ComQuery:
		err = c.handleQuery(data)

	case mysql.ComFieldList:
		err = c.handleFieldList(data)

	case mysql.ComInitDB:
		err = c.handleUseDB(data)

	case mysql.ComPing:
		err = c.handlePing()

	case mysql.ComStmtPrepare:
		err = c.handleStmtPrepare(data)

	case mysql.ComStmtExecute:
		err = c.handleStmtExec(data)

	case mysql.ComStmtClose:
		err = c.handleStmtClose(data)

	default:
		// Covers ComStmtFetch, ComStmtReset and ComStmtSendLongData as well
		// as every other command without a handler.
		err = fmt.Errorf("unsurport cmd %v  %v", cmd, string(data))
	}

	if c.txConn != nil {
		if _, ok := err.(net.Error); ok {
			// The backend connection failed at the network level.
			c.txConn.Close(true)
			c.txConn = nil
		} else {
			// Mirror the backend's status flags, then release the connection
			// unless a transaction or a prepared statement still needs it.
			c.status = c.txConn.GetStatus()
			if !c.isTx() && len(c.stmt) == 0 {
				c.txConn.Close(false)
				c.txConn = nil
			}
		}
	}

	// Per-command statistics.
	elasped := time.Since(start)
	cmdStr, ok := ComNameMapping[cmd]
	if !ok {
		cmdStr = "Unknown"
	}
	stats.Stater.Counter(fmt.Sprintf("cmd.%s.%s", c.user, cmdStr), 1)       // user.cmd
	stats.Stater.Timing(fmt.Sprintf("time.%s.%s", c.user, cmdStr), elasped) // time.cmd
	stats.Stater.Gauge("packet.read", uint64(len(data)))                    // packet.read
	if err != nil {
		stats.Stater.Counter(fmt.Sprintf("error.%s.%s", c.user, cmdStr), 1) // error.cmd
	}
	c.logSQL(data, elasped)
	SlowLog.log(c.user, c.dbname, c.host, c.nodeAddr, elasped, c.queryLog)
	return err
}

//----------------------------------------------------------
// dispatch handle
//----------------------------------------------------------

// handleQuit serves COM_QUIT by closing the client. It always returns nil and
// writes no response.
func (c *Client) handleQuit() error {
	c.Close()
	return nil
}

// handlePing serves COM_PING by replying with an OK packet. The write error
// is returned to the caller; it was previously discarded.
func (c *Client) handlePing() (err error) {
	return c.writeOK()
}

// handleStmtPrepare serves COM_STMT_PREPARE. It acquires a backend
// connection, selects the client's database, prepares the statement there,
// registers the returned statement ID and relays the backend's response.
func (c *Client) handleStmtPrepare(data []byte) error {
	if err := c.grabConn(); err != nil {
		return err
	}
	if err := c.txConn.UseDB(c.dbname, c.collation); err != nil {
		return err
	}

	query := string(data[1:])
	packets, prepared, err := c.txConn.Prepare(query)
	if err != nil {
		return err
	}

	// Registering the ID keeps the backend connection pinned to this client.
	c.stmt[prepared.id] = struct{}{}
	return c.writeResultPackets(packets)
}

// handleStmtExec serves COM_STMT_EXECUTE by forwarding the packet to the
// backend connection that holds the prepared statement and relaying the
// result.
//
// It returns an error when the packet is too short to carry a statement ID,
// and ErrUnknownStmtHandler when the ID is not registered or no backend
// connection is held for it.
func (c *Client) handleStmtExec(data []byte) error {
	// Command byte plus a 4-byte statement ID; guard before slicing
	// client-supplied data.
	if len(data) < 5 {
		return errors.New("malformed COM_STMT_EXECUTE packet")
	}

	stmtID := binary.LittleEndian.Uint32(data[1:5])
	// The nil check is defensive: a registered statement without a backend
	// connection cannot be executed.
	if _, ok := c.stmt[stmtID]; !ok || c.txConn == nil {
		return mysql.NewErr(mysql.ErrUnknownStmtHandler, stmtID)
	}

	res, err := c.txConn.Query(data)
	if err != nil {
		return err
	}

	return c.writeResultPackets(res)
}

// handleStmtClose serves COM_STMT_CLOSE. A registered statement is removed
// and the close command is forwarded to the backend connection; an unknown ID
// is ignored.
//
// It returns an error when the packet is too short to carry a statement ID,
// or when forwarding the command fails.
func (c *Client) handleStmtClose(data []byte) error {
	// Command byte plus a 4-byte statement ID; guard before slicing
	// client-supplied data.
	if len(data) < 5 {
		return errors.New("malformed COM_STMT_CLOSE packet")
	}

	stmtID := binary.LittleEndian.Uint32(data[1:5])
	if _, ok := c.stmt[stmtID]; !ok {
		return nil
	}
	delete(c.stmt, stmtID)

	// Defensive: with no backend connection there is nothing to forward to.
	if c.txConn == nil {
		return nil
	}
	return c.txConn.WriteCommandPacketUint32(mysql.ComStmtClose, stmtID)
}

// handleFieldList serves COM_FIELD_LIST by running "show columns from
// `<table>`" on the backend and relaying the result.
// https://dev.mysql.com/doc/internals/en/com-field-list.html
// https://dev.mysql.com/doc/refman/5.7/en/show-columns.html
//
// The payload is the table name terminated by NUL, followed by an optional
// field wildcard. Only the table name is used: the previous code pasted the
// NUL byte and the wildcard into the SQL text. The name is backtick-quoted
// with embedded backticks doubled, because it is client-supplied text placed
// in a query.
// NOTE(review): the wildcard is ignored; it could be mapped to a LIKE clause.
func (c *Client) handleFieldList(data []byte) error {
	err := c.grabConn()
	if err != nil {
		return err
	}
	if err := c.txConn.UseDB(c.dbname, c.collation); err != nil {
		return err
	}

	table := data[1:]
	if end := bytes.IndexByte(table, 0); end >= 0 {
		table = table[:end]
	}
	quoted := bytes.Replace(table, []byte("`"), []byte("``"), -1)

	query := make([]byte, 0, 32+len(quoted))
	query = append(query, mysql.ComQuery)
	query = append(query, "show columns from `"...)
	query = append(query, quoted...)
	query = append(query, '`')

	res, err := c.txConn.Query(query)
	if err != nil {
		return err
	}
	return c.writeResultPackets(res)
}

// handleQuery serves COM_QUERY. The statement is parsed and passed to
// preProcQuery, which may answer it locally; otherwise the packet is
// forwarded to a backend connection and the result relayed to the client.
func (c *Client) handleQuery(data []byte) error {
	// TODO: check the SQL statement.
	text := stringutil.String(data[1:])
	stmtNode, err := c.sqlParser.ParseOneStmt(text, "", "")
	if err != nil {
		logs.Error(err)
		return err
	}

	handled, err := c.preProcQuery(stmtNode)
	if err != nil {
		logs.Error(err)
		return err
	}
	// Already answered locally; nothing to forward.
	if handled {
		return nil
	}

	if err = c.grabConn(); err != nil {
		logs.Error(err)
		return err
	}
	if err = c.txConn.UseDB(c.dbname, c.collation); err != nil {
		return err
	}

	packets, err := c.txConn.Query(data)
	if err != nil {
		return err
	}
	return c.writeResultPackets(packets)
}

// handleUseDB serves COM_INIT_DB; the payload after the command byte is the
// database name.
func (c *Client) handleUseDB(data []byte) error {
	dbName := stringutil.String(data[1:])
	return c.useDB(dbName)
}

// useDB switches the client's current database to db after checking that the
// user has privileges on it. Only the name is recorded here; the backend is
// told about it later through txConn.UseDB.
//
// It returns an error when the privilege check fails, when the client is in a
// transaction, or when the OK packet cannot be written. The privilege check
// fails closed: a negative result without an error used to return nil without
// writing any response, leaving the client waiting.
func (c *Client) useDB(db string) (err error) {
	ok, privErr := privilege.AuthChecker.HasPrivilege(c.user+"@"+c.host, db, "", mysql.AllPriv)
	if !ok {
		if privErr != nil {
			return privErr
		}
		return mysql.NewErr(mysql.ErrDBaccessDenied, c.user, c.host, db)
	}

	// The database cannot be changed inside a transaction.
	if c.isTx() {
		return mysql.NewErr(mysql.ErrLockOrActiveTransaction)
	}

	// Record the new name only once the client has been told it succeeded.
	if err = c.writeOK(); err != nil {
		return err
	}
	c.dbname = db
	return nil
}

// logSQL writes a debug log line for one command packet when SQL logging is
// enabled in the proxy configuration. For query and prepare commands it also
// stores the statement text in c.queryLog.
//
// data is the raw packet payload and t the time spent on the command. Empty
// and truncated packets are handled without panicking: an empty payload logs
// nothing, and a statement command too short to carry an ID logs its raw
// payload.
func (c *Client) logSQL(data []byte, t time.Duration) {
	if !conf.GlobalConfig.Proxy.LogSQL {
		return
	}
	if len(data) == 0 {
		return
	}
	cmd := data[0]
	var str = stringutil.String(data[1:])
	var querylog = false
	switch cmd {
	case mysql.ComPing:
		str = "Ping"
	case mysql.ComInitDB:
		// The payload is the database name; log it as is.
	case mysql.ComQuery, mysql.ComStmtPrepare:
		querylog = true
	case mysql.ComQuit:
		str = "ComQuit"
	case mysql.ComStmtClose, mysql.ComStmtExecute:
		// Command byte plus a 4-byte statement ID.
		if len(data) >= 5 {
			str = fmt.Sprintf("ID: %v", binary.LittleEndian.Uint32(data[1:5]))
		}
	}
	if querylog {
		c.queryLog = str
	}
	logs.Debugf("[%d] %s %v %s", c.connectID, ComNameMapping[cmd], t, str)
}

// preProcQuery handles the statements that are answered or checked by the
// proxy before anything is forwarded:
//   - SET statements go to procSetStmt, which decides whether to skip;
//   - USE statements are answered through useDB and always skipped;
//   - ROLLBACK, COMMIT and BEGIN are forwarded without a privilege check;
//   - every other statement is privilege-checked and skipped only when the
//     check fails.
// skip reports whether the caller must not forward the statement; err is the
// error produced while handling or checking it.
func (c *Client) preProcQuery(stmt ast.StmtNode) (skip bool, err error) {
	switch st := stmt.(type) {
	case *ast.SetStmt:
		return c.procSetStmt(st)
	case *ast.UseStmt:
		return true, c.useDB(st.DBName)
	case *ast.RollbackStmt, *ast.CommitStmt, *ast.BeginStmt:
		return false, nil
	}

	if err = c.checkPrivilege(stmt); err != nil {
		return true, err
	}
	return false, nil
}

/*
func (c *Client) procSetCharsetStmt(coll, char string) (skip bool, err error) {
	skip = true
	if coll == "" && char == "" {
		//coll = mysql.DefaultCollationName
		coll = "utf8_general_ci"
		char = mysql.DefaultCharset
	}
	collate, ok1 := mysql.CollationNames[coll]
	if !ok1 {
		charID, ok2 := mysql.CharsetIDs[char]
		if !ok2 {
			err = mysql.NewErr(mysql.ErrCollationCharsetMismatch, coll, char)
			return
		}
		collate = charID
	}

	fmt.Println(collate)
	c.collation = collate
	return skip, c.writeOK()
}
*/

// procCharacterSet stores coll as the client's charset and acknowledges with
// an OK packet. A name missing from mysql.Charsets is rejected with
// ErrCollationCharsetMismatch.
func (c *Client) procCharacterSet(coll string) error {
	_, known := mysql.Charsets[coll]
	if !known {
		return mysql.NewErr(mysql.ErrCollationCharsetMismatch, coll)
	}

	c.collation = coll
	return c.writeOK()
}

// procSetStmt answers selected single-variable SET statements locally instead
// of forwarding them. Such statements are mostly issued right after connect,
// and forwarding them would cost a backend connection and extra time.
//
// skip is true when the statement was answered here. It stays false for
// statements with more than one variable, for variables not listed in
// sqlSetStmt, and for autocommit values that are not literal expressions.
func (c *Client) procSetStmt(st *ast.SetStmt) (skip bool, err error) {
	if len(st.Variables) != 1 {
		return false, nil
	}
	assign := st.Variables[0]

	// SET GLOBAL is acknowledged without being applied anywhere.
	if assign.IsGlobal {
		return true, c.writeOK()
	}
	if !sqlSetStmt[assign.Name] {
		return false, nil
	}

	switch assign.Name {
	case "autocommit":
		literal, isLiteral := assign.Value.(*ast.ValueExpr)
		if !isLiteral {
			return false, nil
		}
		if literal.GetInt64() == 0 {
			c.status &^= mysql.ServerStatusAutocommit
		} else {
			c.status |= mysql.ServerStatusAutocommit
		}
		return true, c.writeOK()

	case "SetNAMES":
		charset := mysql.DefaultCharset
		if literal, isLiteral := assign.Value.(*ast.ValueExpr); isLiteral {
			charset = literal.GetString()
		}
		return true, c.procCharacterSet(charset)

	case "character_set_results", "collation_connection":
		charset := mysql.DefaultCharset
		// NOTE(review): for a literal this takes the charset of the
		// expression's type, not the literal's string value; confirm that is
		// intended.
		if literal, isLiteral := assign.Value.(*ast.ValueExpr); isLiteral {
			charset = literal.GetType().Charset
		} else if column, isColumn := assign.Value.(*ast.ColumnNameExpr); isColumn {
			charset = column.Name.String()
		}
		return true, c.procCharacterSet(charset)

	case "sql_mode":
		// Client "set sql_mode" is acknowledged without being applied.
		return true, c.writeOK()
	}

	return false, nil
}

// procSelectStmt recognises a table-less SELECT whose first field is the
// system variable max_allowed_packet (the Go MySQL driver issues
// "select @@max_allowed_packet") and logs it.
//
// It currently always returns (false, nil); answering the query locally is
// still a TODO. The log call uses Debugf so the %s verb is actually expanded
// (it was passed to Debug before).
func (c *Client) procSelectStmt(st *ast.SelectStmt) (skip bool, err error) {
	if st.From == nil && len(st.Fields.Fields) >= 1 {
		if exp, ok := st.Fields.Fields[0].Expr.(*ast.VariableExpr); ok {
			if exp.IsSystem && exp.Name == "max_allowed_packet" {
				logs.Debugf("procSelectStmt: %s", st.Text())
				// TODO: build the result packets and answer locally:
				// data := make([][]byte, 0, 2)
				// err = c.writeResultPackets(data)
				// skip = true
			}
		}
	}
	return false, nil
}

// checkPrivilege verifies that the client may run node. Only DML statements
// are allowed; anything else is rejected with ErrSpecificAccessDenied.
//
// The required privilege is derived from the statement type and checked
// against the current database. A whitelist of SHOW statements (variables,
// collation, charset, columns, create table) passes without a check.
// DML types not listed in the switch are checked with the zero privilege.
// NOTE(review): confirm how HasPrivilege treats the zero privilege.
//
// The check fails closed: a negative result without an error used to return
// nil, which let the statement through.
func (c *Client) checkPrivilege(node ast.StmtNode) (err error) {
	dml, ok := node.(ast.DMLNode)
	if !ok {
		return mysql.NewErr(mysql.ErrSpecificAccessDenied, node.Text())
	}

	var priv mysql.PrivilegeType
	switch st := dml.(type) {
	case *ast.SelectStmt, *ast.UnionStmt:
		priv = mysql.SelectPriv
	case *ast.UpdateStmt:
		priv = mysql.UpdatePriv
	case *ast.DeleteStmt:
		priv = mysql.DeletePriv
	case *ast.InsertStmt:
		priv = mysql.InsertPriv
	case *ast.ShowStmt:
		// SHOW whitelist: these statements always pass the privilege check.
		switch st.Tp {
		case ast.ShowVariables, ast.ShowCollation, ast.ShowCharset, ast.ShowColumns, ast.ShowCreateTable:
			return nil
		}
		priv = mysql.ShowDBPriv
	}

	granted, privErr := privilege.AuthChecker.HasPrivilege(c.user+"@"+c.host, c.dbname, "", priv)
	if granted {
		return nil
	}
	if privErr != nil {
		return privErr
	}
	return mysql.NewErr(mysql.ErrSpecificAccessDenied, node.Text())
}
